from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Union

import numpy as np
from scipy.signal import fftconvolve


# Public API exported by ``from <module> import *``: the configuration
# dataclass, the φ / |∇φ| entry point, and the pixel-distance helper.
__all__ = [
    "PoissonCfg",
    "compute_phi_from_mask",
    "radial_dist",
]


# -----------------------------
# Config
# -----------------------------
@dataclass
class PoissonCfg:
    """Configuration for ``compute_phi_from_mask``.

    Values are checked on construction. A typo such as
    ``background_subtract="rings"``, or an inverted ring, then fails loudly
    instead of silently disabling or zeroing the background subtraction.

    Raises:
        ValueError: if ``background_subtract`` is not "none"/"ring"
            (case-insensitive), or if ring mode is selected and the ring
            fractions do not satisfy ``0 <= ring_inner_fracL <= ring_outer_fracL``.
    """

    # Softening length in pixels: a number (e.g. 1.0) or a string of the form
    # "<coef>*sigma" (e.g. "1.5*sigma"), which is scaled by sigma_max at solve time.
    epsilon_soften: Union[str, float] = "1.0*sigma"

    # Background subtraction mode, case-insensitive: "none" or "ring".
    background_subtract: str = "none"
    # Inner/outer ring radii as fractions of the mask side length L, measured
    # from the array centre; only used when background_subtract is "ring".
    ring_inner_fracL: float = 0.35
    ring_outer_fracL: float = 0.45

    def __post_init__(self) -> None:
        mode = self.background_subtract
        if not isinstance(mode, str) or mode.lower() not in ("none", "ring"):
            raise ValueError(
                f"background_subtract must be 'none' or 'ring', got {mode!r}"
            )
        # Ring radii only matter in ring mode; don't reject unused settings.
        if mode.lower() == "ring":
            inner = float(self.ring_inner_fracL)
            outer = float(self.ring_outer_fracL)
            if not 0.0 <= inner <= outer:
                raise ValueError(
                    "ring fractions must satisfy 0 <= ring_inner_fracL <= "
                    f"ring_outer_fracL, got {inner!r} and {outer!r}"
                )


# -----------------------------
# Helpers
# -----------------------------
def _parse_eps(expr: Union[str, float], sigma_max: float) -> float:
    """Parse the ``epsilon_soften`` setting into a softening length.

    Accepted forms:
      * a number (int/float), returned as ``float``;
      * a numeric string, e.g. ``"1.5"``;
      * ``"sigma"``, which gives ``sigma_max``;
      * ``"<coef>*sigma"`` or ``"sigma*<coef>"``, which gives ``coef * sigma_max``.
    Strings are case-insensitive and whitespace is ignored in the sigma forms.

    Raises:
        ValueError: if a string mentioning ``sigma`` is not one of the forms
            above (trailing/leading garbage is rejected rather than dropped),
            or if a coefficient or plain string is not a valid float.
    """
    if isinstance(expr, (int, float)):
        return float(expr)
    s = str(expr).strip().lower()
    if "sigma" not in s:
        return float(s)  # plain numeric string

    compact = "".join(s.split())
    if compact == "sigma":
        return float(sigma_max)
    if compact.endswith("*sigma"):
        coef = compact[: -len("*sigma")]
    elif compact.startswith("sigma*"):
        coef = compact[len("sigma*"):]
    else:
        raise ValueError(
            f"cannot parse epsilon_soften={expr!r}; expected a number, "
            "'sigma', '<coef>*sigma' or 'sigma*<coef>'"
        )
    try:
        factor = float(coef)
    except ValueError as err:
        raise ValueError(
            f"invalid sigma coefficient in epsilon_soften={expr!r}"
        ) from err
    return factor * float(sigma_max)


def radial_dist(shape: Tuple[int, int], cx: float, cy: float) -> np.ndarray:
    """Return an ``shape``-sized array of pixel distances to the point (cx, cy).

    Pixel ``[row, col]`` is treated as the point ``(x=col, y=row)``, so ``cx``
    is a column coordinate and ``cy`` a row coordinate, both in pixels.
    """
    rows, cols = shape
    # A column vector of row offsets and a row vector of column offsets;
    # hypot broadcasts them to the full (rows, cols) grid.
    dy = np.arange(rows)[:, np.newaxis] - cy
    dx = np.arange(cols)[np.newaxis, :] - cx
    return np.hypot(dx, dy)


def _softened_kernel(L: int, eps: float) -> np.ndarray:
    """
    Build a (2L+1)x(2L+1) softened 1/r kernel for *linear* convolution.

    The entry at offset (dy, dx) from the centre is
    ``1 / sqrt(dx**2 + dy**2 + eps**2)``; the kernel is then normalized to
    unit sum. Normalization rescales the convolved field and its gradient by a
    single global constant, so relative slopes are unchanged. The half-width L
    covers every pixel-to-pixel offset (at most L-1) in an LxL grid.

    Raises:
        ValueError: if ``L`` is negative, or if ``eps`` is zero or non-finite.
            A zero softening length places 1/0 at the centre and would turn
            the normalized kernel into NaNs. The sign of ``eps`` is irrelevant
            because only ``eps**2`` is used.
    """
    if L < 0:
        raise ValueError(f"kernel half-width L must be >= 0, got {L!r}")
    eps = float(eps)
    if eps == 0.0 or not np.isfinite(eps):
        raise ValueError(f"softening length eps must be finite and non-zero, got {eps!r}")
    y, x = np.ogrid[-L:L + 1, -L:L + 1]
    ker = 1.0 / np.sqrt(x * x + y * y + eps * eps)
    ker /= ker.sum()
    return ker


# -----------------------------
# Main
# -----------------------------
def compute_phi_from_mask(mask: np.ndarray, sigma_max: float, cfg: PoissonCfg) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute φ for a square source mask by zero-padded (aperiodic) linear
    convolution with a softened 1/r kernel. Optionally subtract the mean of φ
    over a centred annulus ("ring") as a background level.

    Args:
        mask: square 2-D (L, L) array-like source, intended to be binary;
            converted to float64.
        sigma_max: scale used to resolve ``cfg.epsilon_soften`` strings such
            as "1.5*sigma".
        cfg: solver configuration (softening, background mode, ring radii).

    Returns:
        (phi, gnorm): two (L, L) arrays. ``gnorm`` is |∇φ| computed with
        ``np.gradient`` at unit pixel spacing.

    Raises:
        ValueError: if ``mask`` is not a square 2-D array with L >= 2, or if
            ``cfg.background_subtract`` is neither "none" nor "ring"
            (case-insensitive).
    """
    mask = np.asarray(mask)
    if mask.ndim != 2 or mask.shape[0] != mask.shape[1]:
        raise ValueError(f"mask must be (L,L), got shape {mask.shape}")
    L = mask.shape[0]
    if L < 2:
        # np.gradient needs at least two samples along each axis.
        raise ValueError(f"mask side length must be >= 2, got {L}")

    mode = cfg.background_subtract.lower()
    if mode not in ("none", "ring"):
        raise ValueError(
            f"background_subtract must be 'none' or 'ring', got {cfg.background_subtract!r}"
        )

    eps = _parse_eps(cfg.epsilon_soften, float(sigma_max))
    ker = _softened_kernel(L=L, eps=eps)

    src = mask.astype(np.float64, copy=False)
    # The (2L+1)-wide kernel is centred at index L. Its "same"-mode output is
    # therefore the LxL window aligned with src: each output pixel sums
    # contributions from every source pixel, with zero padding (no wrap-around).
    phi = fftconvolve(src, ker, mode="same")

    if mode == "ring":
        centre = (L - 1) / 2.0
        r = radial_dist(src.shape, cx=centre, cy=centre)
        rin = float(cfg.ring_inner_fracL) * L
        rout = float(cfg.ring_outer_fracL) * L
        ring_mask = (r >= rin) & (r <= rout)
        # An empty ring (e.g. radii outside the grid) leaves φ unshifted.
        ring_mean = float(np.mean(phi[ring_mask])) if np.any(ring_mask) else 0.0
        phi = phi - ring_mean

    gy, gx = np.gradient(phi)
    gnorm = np.hypot(gx, gy)

    return phi, gnorm
